// -*- mode: c++ -*-
// Copyright (c) 2025, Intel Corporation
// SPDX-License-Identifier: BSD-3-Clause

// @file short_vec.isph
// @brief Portion of the ISPC standard library supporting short vector types.

// Unlike stdlib.isph which is included by default, this file is not implicitly
// included. The user must explicitly include it in their ISPC code to use the
// functions defined here. We do this to avoid polluting every single
// compilation with extra processing of quite a few template functions.

#pragma once

// Below are the actual implementations of the functions, written as templates.

// These are macro generators for template functions that extend element-wise
// standard library functions by adding support for short vectors. They take a
// function name as an argument and help reduce copy-paste. There are two
// versions of the macro: one for uniform and one for varying short vectors.
// The number of function arguments in the macros depends on their suffix,
// e.g., ARG2 indicates that the function takes two arguments.

// Generator for the uniform, one-argument short-vector overload of FUNC.
// Expands to `template <typename T, uint N> uniform T<N> FUNC(uniform T<N> a)`,
// which calls FUNC on every element a[k] for k in [0, N) from a foreach loop
// and returns the per-element results collected into a uniform T<N>.
#define SHORT_VEC_UNIFORM_ARG1(FUNC)                                                                                   \
    template <typename T, uint N> uniform T<N> FUNC(uniform T<N> a) {                                                  \
        uniform T<N> output;                                                                                           \
        foreach (k = 0 ... N) {                                                                                        \
            output[k] = FUNC(a[k]);                                                                                    \
        }                                                                                                              \
        return output;                                                                                                 \
    }

// Generator for the varying, one-argument short-vector overload of FUNC.
// Expands to `template <typename T, uint N> varying T<N> FUNC(varying T<N> a)`,
// which walks the N elements with a uniform loop counter, calls FUNC on each
// element a[k], and returns the per-element results as a varying T<N>.
#define SHORT_VEC_VARYING_ARG1(FUNC)                                                                                   \
    template <typename T, uint N> varying T<N> FUNC(varying T<N> a) {                                                  \
        varying T<N> output;                                                                                           \
        for (uniform int k = 0; k < N; ++k) {                                                                          \
            output[k] = FUNC(a[k]);                                                                                    \
        }                                                                                                              \
        return output;                                                                                                 \
    }

// Generator for the uniform, one-argument short-vector overload of a predicate
// FUNC. Same shape as SHORT_VEC_UNIFORM_ARG1, except that the generated
// function returns uniform bool<N>: element k of the result holds FUNC(a[k]).
#define SHORT_VEC_UNIFORM_ARG1_BOOLRET(FUNC)                                                                           \
    template <typename T, uint N> uniform bool<N> FUNC(uniform T<N> a) {                                               \
        uniform bool<N> output;                                                                                        \
        foreach (k = 0 ... N) {                                                                                        \
            output[k] = FUNC(a[k]);                                                                                    \
        }                                                                                                              \
        return output;                                                                                                 \
    }

// Generator for the varying, one-argument short-vector overload of a predicate
// FUNC. Same shape as SHORT_VEC_VARYING_ARG1, except that the generated
// function returns varying bool<N>: element k of the result holds FUNC(a[k]).
#define SHORT_VEC_VARYING_ARG1_BOOLRET(FUNC)                                                                           \
    template <typename T, uint N> varying bool<N> FUNC(varying T<N> a) {                                               \
        varying bool<N> output;                                                                                        \
        for (uniform int k = 0; k < N; ++k) {                                                                          \
            output[k] = FUNC(a[k]);                                                                                    \
        }                                                                                                              \
        return output;                                                                                                 \
    }

// Generator for the uniform, two-argument short-vector overload of FUNC.
// Expands to `uniform T<N> FUNC(uniform T<N> a, uniform T<N> b)`; inside a
// foreach loop over [0, N) it stores FUNC(a[k], b[k]) into element k of the
// returned vector. Both operands share the same element type T and length N.
#define SHORT_VEC_UNIFORM_ARG2(FUNC)                                                                                   \
    template <typename T, uint N> uniform T<N> FUNC(uniform T<N> a, uniform T<N> b) {                                  \
        uniform T<N> output;                                                                                           \
        foreach (k = 0 ... N) {                                                                                        \
            output[k] = FUNC(a[k], b[k]);                                                                              \
        }                                                                                                              \
        return output;                                                                                                 \
    }

// Generator for the varying, two-argument short-vector overload of FUNC.
// Expands to `varying T<N> FUNC(varying T<N> a, varying T<N> b)`; a loop with
// a uniform counter stores FUNC(a[k], b[k]) into element k of the returned
// vector. Both operands share the same element type T and length N.
#define SHORT_VEC_VARYING_ARG2(FUNC)                                                                                   \
    template <typename T, uint N> varying T<N> FUNC(varying T<N> a, varying T<N> b) {                                  \
        varying T<N> output;                                                                                           \
        for (uniform int k = 0; k < N; ++k) {                                                                          \
            output[k] = FUNC(a[k], b[k]);                                                                              \
        }                                                                                                              \
        return output;                                                                                                 \
    }

// Generator for the uniform, three-argument short-vector overload of FUNC.
// Expands to `uniform T<N> FUNC(uniform T<N> a, uniform T<N> b, uniform T<N> c)`;
// inside a foreach loop over [0, N) it stores FUNC(a[k], b[k], c[k]) into
// element k of the returned vector.
#define SHORT_VEC_UNIFORM_ARG3(FUNC)                                                                                   \
    template <typename T, uint N> uniform T<N> FUNC(uniform T<N> a, uniform T<N> b, uniform T<N> c) {                  \
        uniform T<N> output;                                                                                           \
        foreach (k = 0 ... N) {                                                                                        \
            output[k] = FUNC(a[k], b[k], c[k]);                                                                        \
        }                                                                                                              \
        return output;                                                                                                 \
    }

// Generator for the varying, three-argument short-vector overload of FUNC.
// Expands to `varying T<N> FUNC(varying T<N> a, varying T<N> b, varying T<N> c)`;
// a loop with a uniform counter stores FUNC(a[k], b[k], c[k]) into element k
// of the returned vector.
#define SHORT_VEC_VARYING_ARG3(FUNC)                                                                                   \
    template <typename T, uint N> varying T<N> FUNC(varying T<N> a, varying T<N> b, varying T<N> c) {                  \
        varying T<N> output;                                                                                           \
        for (uniform int k = 0; k < N; ++k) {                                                                          \
            output[k] = FUNC(a[k], b[k], c[k]);                                                                        \
        }                                                                                                              \
        return output;                                                                                                 \
    }

// Convenience wrappers: each one emits both the uniform and the varying
// short-vector overload of FUNC for the given argument count.

// One argument, result has the same element type as the input.
#define SHORT_VEC_ARG1(FUNC)                                                                                           \
    SHORT_VEC_UNIFORM_ARG1(FUNC)                                                                                       \
    SHORT_VEC_VARYING_ARG1(FUNC)

// One argument, result is a bool short vector (predicate-style functions).
#define SHORT_VEC_ARG1_BOOLRET(FUNC)                                                                                   \
    SHORT_VEC_UNIFORM_ARG1_BOOLRET(FUNC)                                                                               \
    SHORT_VEC_VARYING_ARG1_BOOLRET(FUNC)

// Two arguments of the same short-vector type.
#define SHORT_VEC_ARG2(FUNC)                                                                                           \
    SHORT_VEC_UNIFORM_ARG2(FUNC)                                                                                       \
    SHORT_VEC_VARYING_ARG2(FUNC)

// Three arguments of the same short-vector type.
#define SHORT_VEC_ARG3(FUNC)                                                                                           \
    SHORT_VEC_UNIFORM_ARG3(FUNC)                                                                                       \
    SHORT_VEC_VARYING_ARG3(FUNC)

// Generate the template functions for short vectors

// Each line below expands to a uniform and a varying template overload of the
// named function that forwards, element by element, to the existing overloads
// of that same name.

// Absolute value and element-wise minimum / maximum.
SHORT_VEC_ARG1(abs)
SHORT_VEC_ARG2(min)
SHORT_VEC_ARG2(max)

// Three-operand clamp; all three operands are short vectors of the same type.
SHORT_VEC_ARG3(clamp)

// Classification predicates; these return bool<N> rather than T<N>.
SHORT_VEC_ARG1_BOOLRET(isnan)
SHORT_VEC_ARG1_BOOLRET(isinf)
SHORT_VEC_ARG1_BOOLRET(isfinite)

// Rounding functions and reciprocals.
SHORT_VEC_ARG1(round)
SHORT_VEC_ARG1(floor)
SHORT_VEC_ARG1(ceil)
SHORT_VEC_ARG1(trunc)
SHORT_VEC_ARG1(rcp)
SHORT_VEC_ARG1(rcp_fast)

// Roots, trigonometric, exponential and logarithm functions (one argument).
SHORT_VEC_ARG1(sqrt)
SHORT_VEC_ARG1(rsqrt)
SHORT_VEC_ARG1(rsqrt_fast)
SHORT_VEC_ARG1(sin)
SHORT_VEC_ARG1(asin)
SHORT_VEC_ARG1(cos)
SHORT_VEC_ARG1(acos)
SHORT_VEC_ARG1(tan)
SHORT_VEC_ARG1(atan)
SHORT_VEC_ARG1(exp)
SHORT_VEC_ARG1(log)
SHORT_VEC_ARG1(cbrt)

// Two-argument math functions.
SHORT_VEC_ARG2(atan2)
SHORT_VEC_ARG2(pow)
SHORT_VEC_ARG2(fmod)

// The generator macros above are no longer needed past this point.

// Remove every helper macro defined in this header so that none of them is
// visible to code that includes it.

// Per-qualifier (uniform / varying) generators.
#undef SHORT_VEC_UNIFORM_ARG1
#undef SHORT_VEC_VARYING_ARG1
#undef SHORT_VEC_UNIFORM_ARG1_BOOLRET
#undef SHORT_VEC_VARYING_ARG1_BOOLRET
#undef SHORT_VEC_UNIFORM_ARG2
#undef SHORT_VEC_VARYING_ARG2
#undef SHORT_VEC_UNIFORM_ARG3
#undef SHORT_VEC_VARYING_ARG3
// Combined uniform + varying wrappers.
#undef SHORT_VEC_ARG1
#undef SHORT_VEC_ARG1_BOOLRET
#undef SHORT_VEC_ARG2
#undef SHORT_VEC_ARG3

// Short vector functions with a custom implementation

// Short-vector select driven by a single uniform condition.
//
// @param cond  one uniform bool applied to every element
// @param t     values used for the elements where the condition is true
// @param f     values used for the elements where the condition is false
// @return      uniform T<N> whose element k is select(cond, t[k], f[k])
template <typename T, uint N> inline uniform T<N> select(uniform bool cond, uniform T<N> t, uniform T<N> f) {
    uniform T<N> picked;
    // The same condition is passed to the per-element select for every index.
    foreach (k = 0 ... N) {
        picked[k] = select(cond, t[k], f[k]);
    }
    return picked;
}

// Short-vector select driven by a per-element condition vector.
//
// @param cond  uniform bool<N>; element k decides the source of result element k
// @param t     values used for the elements where cond[k] is true
// @param f     values used for the elements where cond[k] is false
// @return      uniform T<N> whose element k is select(cond[k], t[k], f[k])
template <typename T, uint N> inline uniform T<N> select(uniform bool<N> cond, uniform T<N> t, uniform T<N> f) {
    uniform T<N> picked;
    // Each element consults its own entry of the condition vector.
    foreach (k = 0 ... N) {
        picked[k] = select(cond[k], t[k], f[k]);
    }
    return picked;
}
